import re
import pysam
from itertools import combinations


class Merge:
    """Merge several single variants on one chromosome into one complex variant.

    Usage: construct with a FASTA path, call ``增加突变`` once per variant,
    then call ``合并`` to build the merged ``Var`` (also kept in
    ``self.result_var``).
    """

    def __init__(self, fasta):
        """Open the reference used to fill in bases between the variants.

        Args:
            fasta: path to an indexed FASTA file readable by ``pysam.FastaFile``.
        """
        self.fasta = fasta
        self.Fa = pysam.FastaFile(fasta)
        self.vars = []

    def 增加突变(self, in_var):
        """Add one variant to merge; call repeatedly to add several.

        ``in_var`` is parsed (and left-aligned) by ``Var``.
        """
        self.vars.append(Var(in_var, self.fasta))

    def _check_chr(self):
        """Raise ValueError unless all variants lie on the same chromosome."""
        Chrs = set(var.Chr for var in self.vars)
        if len(Chrs) > 1:
            msg = f'染色体不一致 {Chrs}'
            raise ValueError(msg)

    def _check_pos(self):
        """Raise ValueError if any two variants share a reference position."""
        for a, b in combinations(self.vars, 2):
            if a.with_same_pos(b):
                msg = f'两个突变存在共同位置，不能合并. {a} {b}'
                raise ValueError(msg)

    def _sort_var_by_pos(self):
        """Sort ``self.vars`` numerically by ``Start``.

        At an equal ``Start`` an insertion (``Ref == '-'``) goes after the
        other variant, because an insertion sits after its anchor base.
        """
        # Integer sort key: the former string keys ("100_..." vs "99_...")
        # sorted lexicographically, mis-ordering positions with different digit
        # counts, and duplicate keys silently dropped variants. sorted() is
        # stable, so variants with identical keys keep their input order.
        self.vars = sorted(self.vars, key=lambda var: (var.Start, var.Ref == '-'))

    def 合并(self) -> 'Var':
        """Merge the added variants into one complex variant.

        Reference bases between consecutive variants are taken from the FASTA,
        each variant's ALT is spliced in, and bases shared by the ends of the
        resulting REF and ALT are trimmed away.

        Returns:
            The merged ``Var`` (also stored in ``self.result_var``).

        Raises:
            ValueError: no variants were added, the variants lie on different
                chromosomes, or two of them overlap.
        """
        if not self.vars:
            raise ValueError('没有需要合并的突变')
        self._check_chr()
        self._check_pos()
        self._sort_var_by_pos()
        first, last = self.vars[0], self.vars[-1]
        chrom = first.Chr
        # An insertion sits after its anchor base, so the first reference base
        # it affects is Start + 1.
        start_pos = first.Start + 1 if first.Ref == '-' else first.Start
        # Uppercase so soft-masked (lowercase) reference bases compare equal to
        # the uppercase ALT bases during trimming.
        ref_seq = self.Fa.fetch(region=f'{chrom}:{start_pos}-{last.End}').upper()
        base_dict = dict(zip(range(start_pos, last.End + 1), ref_seq))

        # Build ALT: reference gap between consecutive variants, then each ALT.
        parts = []
        last_end = 0
        for n, var in enumerate(self.vars):
            if n > 0:
                end = var.Start + 1 if var.Ref == '-' else var.Start
                parts.extend(base_dict[i] for i in range(last_end + 1, end))
            parts.append(var.Alt)
            last_end = var.End
        # Drop the deletion placeholder '-'.
        alt_seq = ''.join(parts).replace('-', '')

        # Trim the common suffix (moves End left).
        limit = min(len(ref_seq), len(alt_seq))
        suffix = 0
        while suffix < limit and ref_seq[-1 - suffix] == alt_seq[-1 - suffix]:
            suffix += 1
        ref_seq = ref_seq[:len(ref_seq) - suffix]
        alt_seq = alt_seq[:len(alt_seq) - suffix]
        End = last.End - suffix

        # Trim the common prefix (moves Start right). The bound must come from
        # the already suffix-trimmed lengths; reusing the old bound could index
        # past the end of the shorter sequence.
        limit = min(len(ref_seq), len(alt_seq))
        prefix = 0
        while prefix < limit and ref_seq[prefix] == alt_seq[prefix]:
            prefix += 1
        ref_seq = ref_seq[prefix:] or '-'
        alt_seq = alt_seq[prefix:] or '-'

        if ref_seq == '-':
            # Pure insertion: Start is the anchor base and End = Start + 1,
            # the flanking-positions form that Var.格式化 collapses to Start == End.
            Start = start_pos + prefix - 1
            End = End + 1
        else:
            Start = start_pos + prefix

        def _short(seq):
            # Abbreviate long sequences in the console message only.
            return seq if len(seq) < 10 else (seq[:10] + "...")

        o_var = f'{chrom} {Start} {End} {_short(ref_seq)} {_short(alt_seq)}'
        print(f'合并后的突变: [{o_var}]')
        self.result_var = Var(f'{chrom} {Start} {End} {ref_seq} {alt_seq}', self.fasta)
        return self.result_var

    def __str__(self):
        """Return the merged variant's text form (requires a prior ``合并`` call)."""
        return str(self.result_var)


class Var:
    """A single variant: chromosome, start, end, REF and ALT (plus its type).

    The input text is normalised and the variant is left-aligned against the
    reference so that variants can be compared. Coordinates are 1-based and
    inclusive (they are used directly in ``chr:start-end`` fetch regions).
    For an insertion (``Ref == '-'``) parsed from VCF-style or flanking-position
    input, ``Start == End`` is the anchor base after which ``Alt`` is inserted.
    """

    def __init__(self, var, fasta):
        """Parse ``var`` and left-align it.

        Args:
            var: variant text such as ``"chr1 100 101 TC T"`` or
                ``"1:100-101 TC T"``.
            fasta: path to the indexed reference FASTA; an open
                ``pysam.FastaFile`` is kept in ``self.fasta``.
        """
        self.in_var = var
        self.fasta = pysam.FastaFile(fasta)
        self.Chr, self.Start, self.End, self.Ref, self.Alt, self.Type = self.格式化(var)
        self.左对齐()

    def 格式化(self, var):
        """Parse variant text into ``(Chr, Start, End, Ref, Alt, Type)``.

        Fields may be separated by spaces, tabs, ``:`` or ``_`` and a
        ``start-end`` range is accepted. VCF-style indels that share a leading
        anchor (``T -> TC``, ``TC -> T``) are rewritten to ``-`` notation.

        Raises:
            ValueError: empty/invalid bases, letters other than ATCG/-, a
                chromosome other than chr1-22/X/Y, or the wrong field count.
        """
        var = re.sub(r'[ :\t_]+', ' ', var.strip())
        # Turn the first "start-end" range into two space-separated fields.
        m = re.findall(r'(\d *?\- *?\d)', var)
        if m:
            var = var.replace(m[0], m[0].replace('-', ' '))
        var = re.sub(r' +', ' ', var)
        Chr, Start, End, Ref, Alt = var.strip().split(' ')
        Ref = Ref.upper()
        Alt = Alt.upper()
        Start = int(Start)
        End = int(End)
        if Ref.startswith(Alt):
            # Deletion with a shared anchor (TC -> T): drop the anchor and move
            # Start past it. The anchor length must be read before Alt is
            # replaced by '-' (otherwise len('-') == 1 is always added).
            anchor_len = len(Alt)
            Ref = Ref[anchor_len:]
            Alt = '-'
            Start += anchor_len
        elif Alt.startswith(Ref):
            # Insertion with a shared anchor (T -> TC): Start stays on the anchor.
            Alt = Alt[len(Ref):]
            Ref = '-'
        elif Ref == '-':
            # Insertion given by its two flanking positions: keep the left one.
            if Start == End - 1:
                End -= 1
        Chr = Chr.upper()
        if Chr.startswith('CHR'):
            Chr = Chr.replace('CHR', 'chr')
        else:
            Chr = f'chr{Chr}'
        if Ref == '' or Alt == '' or (Ref == '-' and Alt == '-'):
            raise ValueError(f'突变的碱基写法错误 {self.in_var}')
        if any(c not in 'ATCG-' for c in Ref + Alt):
            raise ValueError(f'突变的碱基字母错误 {self.in_var}')
        valid_chrs = {f'chr{c}' for c in [*range(1, 23), 'X', 'Y']}
        if Chr not in valid_chrs:
            raise ValueError(f'染色体有误 {Chr}')
        Type = self.突变类型判断(Ref, Alt)
        return Chr, Start, End, Ref, Alt, Type

    @staticmethod
    def 最小重复碱基(base: str):
        """Return the shortest unit whose repetition forms ``base``.

        Used as the tandem-repeat unit during left alignment; ``base`` itself
        is returned when it is not a repeat (and for the empty string).
        """
        total = len(base)
        for size in range(1, total + 1):
            unit = base[:size]
            if total % size == 0 and unit * (total // size) == base:
                return unit
        return base

    @staticmethod
    def 突变类型判断(Ref, Alt):
        """Classify a variant as ``INS``, ``DEL``, ``DELINS`` or ``SNV``.

        Any substitution with more than one base on either side is ``DELINS``
        (a single base replaced by several, e.g. ``A -> TG``, is not an SNV).
        """
        if Ref == '-':
            return 'INS'
        if Alt == '-':
            return 'DEL'
        if len(Ref) > 1 or len(Alt) > 1:
            return 'DELINS'
        return 'SNV'

    def _left_seq(self, end, length):
        """Return up to ``length`` reference bases ending at 1-based ``end``.

        The bases are uppercased and reversed, so index 0 is the base at
        ``end``. The window is clipped at position 1 so the region never has a
        non-positive start; ``''`` is returned when ``end < 1``.
        """
        if end < 1:
            return ''
        begin = max(1, end - length + 1)
        return self.fasta.fetch(region=f'{self.Chr}:{begin}-{end}').upper()[::-1]

    def 左对齐(self):
        """Left-align a DEL or INS in place; SNV and DELINS are left untouched.

        First shift left by whole copies of the minimal repeat unit (looking
        at most 500 bp back), then one base at a time while the reference base
        to the left equals the variant's trailing base, updating Ref/Alt.
        """
        if self.Type in ['SNV', 'DELINS']:
            return
        is_ins = self.Type == 'INS'
        unit = self.最小重复碱基(self.Alt if is_ins else self.Ref)
        # For an insertion the anchor itself is the first base to the left of
        # the inserted sequence; for a deletion it is Start - 1.
        local_start = self.Start if is_ins else self.Start - 1
        ref_left_seq = self._left_seq(local_start, 501)
        unit_len = len(unit)
        rev_unit = unit[::-1]
        max_units = 500 // unit_len
        # Count whole repeat units directly to the left (the window is reversed,
        # so compare with the reversed unit). Unlike a for/break loop this also
        # counts the last unit when every window matches.
        count = 0
        while (count < max_units
               and ref_left_seq[count * unit_len:(count + 1) * unit_len] == rev_unit):
            count += 1
        if count:
            move_len = count * unit_len
            self.Start -= move_len
            self.End -= move_len

        # Single-base shifting for partial repeats.
        if is_ins:
            left_seq = self._left_seq(self.Start, 151)
            target_seq = self.Alt[::-1]
        else:
            left_seq = self._left_seq(self.Start - 1, 150)
            target_seq = self.Ref[::-1]
        mv_len = 0
        # zip stops at the shorter sequence, so a short window cannot overrun.
        for base, left_base in zip(target_seq, left_seq):
            if base != left_base:
                break
            mv_len += 1
        if mv_len:
            self.Start -= mv_len
            self.End -= mv_len
            if is_ins:
                # The shifted-over reference bases move to the front of Alt.
                self.Alt = f'{left_seq[:mv_len][::-1]}{self.Alt[:-mv_len]}'
            else:
                self.Ref = self.fasta.fetch(region=f'{self.Chr}:{self.Start}-{self.End}').upper()

    def with_same_pos(self, in_var):
        """Return True if the two variants' [Start, End] ranges overlap.

        Insertions do not consume reference positions, so they never overlap.
        """
        if '-' in (in_var.Ref, self.Ref) and (in_var.Ref == '-' or self.Ref == '-'):
            return False
        starts_inside = self.Start <= in_var.Start <= self.End
        ends_inside = self.Start <= in_var.End <= self.End
        covers = in_var.Start <= self.Start and in_var.End >= self.End
        return starts_inside or ends_inside or covers

    def __str__(self):
        return f'[{self.Chr} {self.Start} {self.End} {self.Ref} {self.Alt}]'

    def tab_str(self):
        """Return the variant as tab-separated Chr, Start, End, Ref, Alt."""
        return f'{self.Chr}\t{self.Start}\t{self.End}\t{self.Ref}\t{self.Alt}'

